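A simple example of a neural network written in GNU Octave. The code uses Octave-specific functions (e.g. `printf` and `meansq`), so it needs minor changes to run in MATLAB.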
In [44]:
function y = sigmoid(x, derivative)
  % SIGMOID  Element-wise logistic function, or its derivative.
  %
  %   y = sigmoid(x)        returns 1 ./ (1 + exp(-x)).
  %   y = sigmoid(s, true)  returns s .* (1 - s), the derivative of the
  %                         logistic function written in terms of its
  %                         OUTPUT s = sigmoid(x).  Callers must pass
  %                         already-activated values here, not raw inputs.
  %
  %   derivative defaults to false.
  if (nargin < 2)
    derivative = false;
  end
  if (derivative)
    y = x .* (1 - x);
  else
    % Only compute the logistic function when the derivative was not
    % requested; otherwise it would overwrite the derivative result.
    y = 1.0 ./ (1.0 + exp(-x));
  end
end
In [45]:
function theta = theta_init(in_size, out_size, epsilon)
  % THETA_INIT  Random initial weights for one fully-connected layer.
  %
  %   theta = theta_init(in_size, out_size)
  %   theta = theta_init(in_size, out_size, epsilon)
  %
  %   Returns an out_size x (in_size + 1) matrix.  The extra column holds
  %   the weights of the bias unit.  Entries are drawn uniformly from
  %   (-epsilon, epsilon) so that units start out different from each
  %   other (symmetry breaking).  epsilon defaults to 0.12.
  if (nargin < 3)
    epsilon = 0.12;
  end
  % rand is uniform on (0, 1); scale to (0, 2*epsilon), then shift down.
  theta = rand(out_size, in_size + 1) * 2 * epsilon - epsilon;
end
In [46]:
function [theta1, theta2] = nn_train(X, y, desired_error, max_iterations, epsilon, hidden_nodes)
  % NN_TRAIN  Train a one-hidden-layer sigmoid network by batch gradient descent.
  %
  %   [theta1, theta2] = nn_train(X, y, desired_error)
  %   [theta1, theta2] = nn_train(X, y, desired_error, max_iterations, epsilon, hidden_nodes)
  %
  %   X               m x n matrix, one observation per row.
  %   y               m x k matrix of targets, one row per observation.
  %   desired_error   training stops as soon as mean(meansq(y - output))
  %                   drops below this value.
  %   max_iterations  upper bound on gradient steps (default 100000).
  %   epsilon         half-width of the uniform weight initialisation
  %                   (default 0.12).
  %   hidden_nodes    hidden-layer size; a value <= 0 (the default) selects
  %                   floor(n * 2 / 3 + k).
  %
  %   theta1  (n + 1) x hidden_nodes weights; row 1 multiplies the bias unit.
  %   theta2  (hidden_nodes + 1) x k weights; row 1 multiplies the bias unit.
  %
  %   Progress is printed roughly once per second.  A warning is issued if
  %   the target error is not reached within max_iterations.
  if (nargin < 4)
    max_iterations = 100000;
  end
  if (nargin < 5)
    epsilon = 0.12;
  end
  if (nargin < 6)
    hidden_nodes = 0;
  end

  m = size(X, 1);
  if (size(y, 1) ~= m)
    error("nn_train: X and y must have the same number of rows (%d vs %d)", ...
          m, size(y, 1));
  end
  input_nodes = size(X, 2);
  output_nodes = size(y, 2);
  if (hidden_nodes <= 0)
    hidden_nodes = floor(input_nodes * 2 / 3 + output_nodes);
  end

  % Transpose so that a row of activations times theta gives the next layer.
  theta1 = theta_init(input_nodes, hidden_nodes, epsilon)';
  theta2 = theta_init(hidden_nodes, output_nodes, epsilon)';

  % Loop invariants: the input layer (with its bias column) and the bias
  % column prepended to the hidden layer never change between iterations.
  a1 = [ones(m, 1) X];
  a2_ones = ones(m, 1);

  printf("Training the neural network (%d input, %d hidden, %d output nodes) with %d observations\n", ...
         input_nodes, hidden_nodes, output_nodes, m);

  meansq_error = Inf;  % defined even if max_iterations < 1
  converged = false;
  tic_id = tic();
  for k = 1:max_iterations
    % Feed forward.
    a2 = [a2_ones sigmoid(a1 * theta1)];
    a3 = sigmoid(a2 * theta2);
    a3_delta = y - a3;

    % Test for convergence on every iteration.  This is cheap next to the
    % matrix products above.  It lets training stop with the first weights
    % that meet the target, instead of waiting for the next 1-second report.
    meansq_error = mean(meansq(a3_delta));
    converged = meansq_error < desired_error;

    % Report progress roughly once per second, and once more on success.
    if (converged || toc(tic_id) > 1)
      printf("Iteration: %9d (max:%d), mse: %9f (target:%f)\n", ...
             k, max_iterations, meansq_error, desired_error);
      tic_id = tic();
    end
    if (converged)
      break;
    end

    % Backpropagation.  Using y - a3 directly as the output delta (no
    % sigmoid' factor) corresponds to the cross-entropy gradient for a
    % sigmoid output layer.  The step size is implicitly 1.
    a2_error = a3_delta * theta2';
    a2_delta = a2_error .* sigmoid(a2, true);
    theta2 = theta2 + (a2' * a3_delta) ./ m;
    % Column 1 of a2 is the constant bias unit.  It has no incoming weights,
    % so its delta column is excluded from theta1's update.
    theta1 = theta1 + (a1' * a2_delta(:, 2:end)) ./ m;
  end

  if (~converged)
    warning("nn_train: mse %f did not reach target %f within %d iterations", ...
            meansq_error, desired_error, max_iterations);
  end
end
In [47]:
function a3 = nn_predict(X, theta1, theta2)
  % NN_PREDICT  Forward pass of a one-hidden-layer sigmoid network.
  %
  %   a3 = nn_predict(X, theta1, theta2)
  %
  %   X       m x n matrix, one observation per row.
  %   theta1  (n + 1) x h hidden-layer weights; row 1 multiplies the bias unit.
  %   theta2  (h + 1) x k output-layer weights; row 1 multiplies the bias unit.
  %   a3      m x k matrix of sigmoid outputs, each in [0, 1].
  %
  %   Uses the same layer layout as the weights produced by nn_train.
  bias = ones(size(X, 1), 1);  % shared by both layers
  a2 = sigmoid([bias X] * theta1);
  a3 = sigmoid([bias a2] * theta2);
end
In [48]:
% Demo: learn the XOR truth table (two binary inputs, one binary output).
X = [0, 0
     0, 1
     1, 0
     1, 1];
y = [0 1 1 0]';
[theta1, theta2] = nn_train(X, y, 0.0001);

% Print each input row next to the network's prediction and the target,
% followed by the overall mean squared error.
pred_values = nn_predict(X, theta1, theta2);
fprintf("\n\n Input Values Predicted Actual\n");
disp([X, pred_values, y])
fprintf("\nMean square error of trained model predictions: %f\n", ...
        mean(meansq(y - pred_values)))